{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch import nn\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义最高次项的指数\n",
    "\n",
    "n = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为线性相乘做准备，将x中的每一个元素排列成 [x^1, x^2, ..., x^n]\n",
    "# 因为有多个x，故返回的是一个4x4的二维张量\n",
    "\n",
    "def make_features(x) :\n",
    "    x = x.unsqueeze(1)\n",
    "    return torch.cat([x ** i for i in range(1, n + 1)], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 不加入x^0的原因是，nn.Linear默认自带bias属性（可取消）， 以下来源于http://pytorch.org/docs/master/nn.html\n",
    "# 在页面内搜索nn.Linear即可找到\n",
    "\n",
    "## class torch.nn.Linear(in_features, out_features, bias=True)[source]\n",
    "## Applies a linear transformation to the incoming data: y=Ax+b\n",
    "##\n",
    "## > Parameters:\n",
    "##    - in_features – size of each input sample\n",
    "##    - out_features – size of each output sample\n",
    "##    - bias – If set to False, the layer will not learn an additive bias. Default: True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "#定义目标函数，包括权重和偏执\n",
    "\n",
    "W_target = torch.FloatTensor([random.randint(-1000, 1000) * 0.01 for i in range(n)]).unsqueeze(1)\n",
    "b_target = torch.FloatTensor([random.randint(-100, 1000) * 0.01])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义实际函数\n",
    "\n",
    "def f(x) :\n",
    "    return x.mm(W_target) + b_target[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成训练集。随机取数，然后生成x，然后生成y\n",
    "\n",
    "def get_batch(batch_size = 32, random = None) :\n",
    "    if random is None :\n",
    "        random = torch.randn(batch_size)\n",
    "    batch_size = random.size()[0]\n",
    "    x = make_features(random)\n",
    "    y = f(x)\n",
    "    if torch.cuda.is_available() :\n",
    "        return Variable(x).cuda(), Variable(y).cuda()\n",
    "    else :\n",
    "        return Variable(x), Variable(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造训练网络模型\n",
    "\n",
    "class poly_model(nn.Module) :\n",
    "    def __init__(self, n) :\n",
    "        super().__init__()\n",
    "        self.poly = nn.Linear(n, 1)\n",
    "    def forward(self, x) :\n",
    "        return self.poly(x)\n",
    "\n",
    "# 实例化网络模型\n",
    "if torch.cuda.is_available() :\n",
    "    poly = poly_model(n).cuda()\n",
    "else :\n",
    "    poly = poly_model(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义损失函数为均方误差, 定义优化函数为随机梯度下降，学习率为0.03\n",
    "\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.SGD(poly.parameters(), lr = 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the number of epoches : 7788\n"
     ]
    }
   ],
   "source": [
    "# 开始训练\n",
    "\n",
    "epoch = 0\n",
    "while True :\n",
    "    # 获得数据\n",
    "    batch_x, batch_y = get_batch()\n",
    "    # 前向计算\n",
    "    output = poly(batch_x)\n",
    "    # 计算损失函数\n",
    "    loss = criterion(output, batch_y)\n",
    "    print_loss = loss.data[0]\n",
    "    # 参数更新\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    epoch += 1\n",
    "    if print_loss < 1e-3 :\n",
    "        break\n",
    "\n",
    "print(\"the number of epoches :\", epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f56983a3470>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "x = [random.randint(-200, 200) * 0.01 for i in range(20)]\n",
    "x = np.array(sorted(x))\n",
    "feature_x, y = get_batch(random = torch.from_numpy(x).float())\n",
    "y = y.data.numpy()\n",
    "plt.plot(x, y, 'ro', label='Original data')\n",
    "\n",
    "poly.eval()\n",
    "x_sample = np.arange(-2, 2, 0.01)\n",
    "x, y = get_batch(random = torch.from_numpy(x_sample).float())\n",
    "y = poly(x)\n",
    "y_sample = y.data.numpy()\n",
    "plt.plot(x_sample, y_sample, label = 'Fitting Line')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predicted function : y = 8.05 * x^4 + 7.44 * x^3 + -7.49 * x^2 + 9.16 * x^1 + 4.89\n",
      "real      function : y = 8.06 * x^4 + 7.44 * x^3 + -7.54 * x^2 + 9.18 * x^1 + 4.94\n"
     ]
    }
   ],
   "source": [
    "# 定义函数输出形式\n",
    "def func_format(weight, bias, n) :\n",
    "    func = ''\n",
    "    for i in range(n, 0, -1) :\n",
    "        func += ' {:.2f} * x^{} +'.format(weight[i - 1], i)\n",
    "    return 'y =' + func + ' {:.2f}'.format(bias[0])\n",
    "    \n",
    "predict_weight = poly.poly.weight.data.numpy().flatten()\n",
    "predict_bias = poly.poly.bias.data.numpy().flatten()\n",
    "print('predicted function :', func_format(predict_weight, predict_bias, n))\n",
    "real_W = W_target.numpy().flatten()\n",
    "real_b = b_target.numpy().flatten()\n",
    "print('real      function :', func_format(real_W, real_b, n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
